{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import h5py\n",
    "import scipy.io as io\n",
    "import PIL.Image as Image\n",
    "import numpy as np\n",
    "import os\n",
    "import glob\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.ndimage.filters import gaussian_filter \n",
    "import scipy\n",
    "import json\n",
    "import torchvision.transforms.functional as F\n",
    "from matplotlib import cm as CM\n",
    "from image import *\n",
    "from model import CSRNet\n",
    "import torch\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Preprocessing pipeline: convert the image to a tensor, then normalise each\n",
    "# channel with the mean/std values listed below.\n",
    "# NOTE(review): no executed line visible in this notebook applies `transform` - confirm it is still needed.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(\n",
    "        mean=[0.485, 0.456, 0.406],\n",
    "        std=[0.229, 0.224, 0.225],\n",
    "    ),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Dataset root directory. Set the SHANGHAI_ROOT environment variable to point\n",
    "# at another copy of the data; when it is unset the original local path is used.\n",
    "root = os.environ.get('SHANGHAI_ROOT', '/home/leeyh/Downloads/Shanghai/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Image directories of the dataset splits under `root` (parts A and B, train and test).\n",
    "part_A_train = os.path.join(root, 'part_A_final/train_data', 'images')\n",
    "part_A_test = os.path.join(root, 'part_A_final/test_data', 'images')\n",
    "part_B_train = os.path.join(root, 'part_B_final/train_data', 'images')\n",
    "part_B_test = os.path.join(root, 'part_B_final/test_data', 'images')\n",
    "\n",
    "# Directories to process; only the part A test images are selected here.\n",
    "path_sets = [part_A_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Gather the path of every .jpg file found directly inside each selected directory.\n",
    "img_paths = []\n",
    "for path in path_sets:\n",
    "    jpg_pattern = os.path.join(path, '*.jpg')\n",
    "    img_paths.extend(glob.glob(jpg_pattern))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Instantiate the CSRNet class imported from the local `model` module, using its default arguments.\n",
    "model = CSRNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Move the model to the GPU; this cell requires a CUDA-capable device.\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the checkpoint file from the current working directory.\n",
    "# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load('model_best.pth.tar')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Copy the weights stored under the checkpoint's 'state_dict' key into the model.\n",
    "model.load_state_dict(checkpoint['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 15.50390625\n",
      "1 60.9075317383\n",
      "2 220.511169434\n",
      "3 239.312469482\n",
      "4 252.252349854\n",
      "5 272.965286255\n",
      "6 457.101577759\n",
      "7 651.92250061\n",
      "8 681.363113403\n",
      "9 785.472061157\n",
      "10 838.996322632\n",
      "11 1012.11277771\n",
      "12 1073.18791199\n",
      "13 1074.72886658\n",
      "14 1139.53701782\n",
      "15 1201.55630493\n",
      "16 1316.97366333\n",
      "17 1447.83328247\n",
      "18 1578.5967865\n",
      "19 1622.11299896\n",
      "20 1650.6510849\n",
      "21 1756.23152924\n",
      "22 1810.32888031\n",
      "23 1816.16628265\n",
      "24 1864.3579483\n",
      "25 1886.67713928\n",
      "26 1942.00212097\n",
      "27 1995.17939758\n",
      "28 2031.85020447\n",
      "29 2236.5879364\n",
      "30 2265.58026123\n",
      "31 2272.81892395\n",
      "32 2333.38215637\n",
      "33 2489.37367249\n",
      "34 2560.18891907\n",
      "35 2580.02906799\n",
      "36 2588.45735168\n",
      "37 2661.36177063\n",
      "38 2788.47312927\n",
      "39 2900.07542419\n",
      "40 2900.83190918\n",
      "41 2976.19485474\n",
      "42 2995.64665985\n",
      "43 3058.02416229\n",
      "44 3103.91133881\n",
      "45 3241.52921295\n",
      "46 3906.06101227\n",
      "47 3918.4709549\n",
      "48 3948.23314667\n",
      "49 3953.28383636\n",
      "50 3989.9439621\n",
      "51 4093.59087372\n",
      "52 4180.21788788\n",
      "53 4289.41963959\n",
      "54 4291.79859161\n",
      "55 4302.33109283\n",
      "56 4339.80475616\n",
      "57 4543.3482132\n",
      "58 4626.09952545\n",
      "59 4684.85929108\n",
      "60 4713.21773529\n",
      "61 4801.95433807\n",
      "62 5371.40599823\n",
      "63 5430.06401062\n",
      "64 5463.51008606\n",
      "65 5517.62242126\n",
      "66 5531.55604553\n",
      "67 5531.86631775\n",
      "68 5547.58416748\n",
      "69 5586.52706909\n",
      "70 5588.99980164\n",
      "71 5711.29048157\n",
      "72 5732.70922852\n",
      "73 5764.59197998\n",
      "74 5799.78912354\n",
      "75 6082.0055542\n",
      "76 6110.95211792\n",
      "77 6141.81124878\n",
      "78 6156.64276886\n",
      "79 6169.39850616\n",
      "80 6193.0460434\n",
      "81 6232.38686371\n",
      "82 6257.28891754\n",
      "83 6367.4381485\n",
      "84 6411.16867828\n",
      "85 6790.66916656\n",
      "86 6935.87509918\n",
      "87 7059.09088898\n",
      "88 7105.02100372\n",
      "89 7138.09949493\n",
      "90 7177.44655609\n",
      "91 7284.53981781\n",
      "92 7363.71459198\n",
      "93 7486.82701874\n",
      "94 7518.29531097\n",
      "95 7519.01638031\n",
      "96 7553.38800049\n",
      "97 7622.98153687\n",
      "98 7635.38021851\n",
      "99 7692.21868896\n",
      "100 7695.25094604\n",
      "101 7711.97100067\n",
      "102 7772.21665192\n",
      "103 7800.12631989\n",
      "104 7804.24352264\n",
      "105 7861.25572968\n",
      "106 7887.95476532\n",
      "107 7913.09929657\n",
      "108 7966.69908905\n",
      "109 7977.07558441\n",
      "110 7978.10308075\n",
      "111 8035.75792694\n",
      "112 8099.88143158\n",
      "113 8713.50301361\n",
      "114 8755.26369476\n",
      "115 8764.1200943\n",
      "116 8807.97383881\n",
      "117 8831.98430634\n",
      "118 8899.79549408\n",
      "119 9223.38906097\n",
      "120 9325.20153046\n",
      "121 9335.49025726\n",
      "122 9391.73867035\n",
      "123 9429.91559601\n",
      "124 9475.48241425\n",
      "125 9504.1687851\n",
      "126 9571.56279755\n",
      "127 9621.32411957\n",
      "128 9771.76332855\n",
      "129 9825.34178925\n",
      "130 9829.55644226\n",
      "131 9852.76200867\n",
      "132 9933.64692688\n",
      "133 9962.82582092\n",
      "134 10008.0011597\n",
      "135 10096.1471558\n",
      "136 10133.0110474\n",
      "137 10251.9801483\n",
      "138 10256.5829926\n",
      "139 10268.2446671\n",
      "140 10314.4458389\n",
      "141 10404.9172134\n",
      "142 10450.5001602\n",
      "143 10454.2003555\n",
      "144 10460.2950211\n",
      "145 10468.630188\n",
      "146 10513.0866699\n",
      "147 10540.4085999\n",
      "148 10969.8531189\n",
      "149 11258.8591614\n",
      "150 11272.8908997\n",
      "151 11405.8752747\n",
      "152 11432.7355957\n",
      "153 11438.5995789\n",
      "154 11497.5692749\n",
      "155 11842.559082\n",
      "156 11955.1886826\n",
      "157 12029.658989\n",
      "158 12082.984642\n",
      "159 12297.4541245\n",
      "160 12355.7072983\n",
      "161 12436.2029648\n",
      "162 12455.2758102\n",
      "163 12750.65905\n",
      "164 12854.170433\n",
      "165 12870.6474228\n",
      "166 12913.027977\n",
      "167 12924.1338921\n",
      "168 12944.4942436\n",
      "169 12986.8365135\n",
      "170 13030.5286522\n",
      "171 13076.5168724\n",
      "172 13136.3384972\n",
      "173 13241.9948387\n",
      "174 13256.391201\n",
      "175 13441.6517601\n",
      "176 13550.5347557\n",
      "177 13579.6995201\n",
      "178 13612.6000328\n",
      "179 13623.0231133\n",
      "180 13667.2316093\n",
      "181 13725.9745598\n",
      "75.4174426362\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the model on every image in img_paths: accumulate the absolute\n",
    "# difference between the summed network output and the summed ground-truth\n",
    "# 'density' array, then print the mean of that error over all images.\n",
    "\n",
    "# Per-channel offsets subtracted from the 0-255 scaled RGB tensor.\n",
    "# NOTE(review): presumably the channel means used at training time - confirm.\n",
    "channel_offsets = (92.8207477031, 95.2757037428, 104.877445883)\n",
    "\n",
    "# Fail early with a clear message instead of a ZeroDivisionError at the end.\n",
    "assert img_paths, 'no images found - check root and path_sets'\n",
    "\n",
    "model.eval()  # evaluation mode for layers such as dropout / batch norm, if the model has any\n",
    "mae = 0  # running sum of absolute errors; divided by the image count at the end\n",
    "with torch.no_grad():  # inference only: do not build the autograd graph\n",
    "    for i, img_path in enumerate(img_paths):\n",
    "        img = 255.0 * F.to_tensor(Image.open(img_path).convert('RGB'))\n",
    "        for channel, offset in enumerate(channel_offsets):\n",
    "            img[channel, :, :] = img[channel, :, :] - offset\n",
    "        img = img.cuda()\n",
    "        # Alternative preprocessing, left disabled as in the original:\n",
    "        #img = transform(Image.open(img_path).convert('RGB')).cuda()\n",
    "\n",
    "        # The ground-truth file mirrors the image path: 'images' -> 'ground_truth', '.jpg' -> '.h5'.\n",
    "        gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground_truth')\n",
    "        # The with-block closes the HDF5 handle; np.asarray has already copied the data into memory.\n",
    "        with h5py.File(gt_path, 'r') as gt_file:\n",
    "            groundtruth = np.asarray(gt_file['density'])\n",
    "\n",
    "        output = model(img.unsqueeze(0))  # add a leading batch dimension of size 1\n",
    "        mae += abs(output.detach().cpu().sum().numpy() - np.sum(groundtruth))\n",
    "        # Same text as the Python 2 statement `print i,mae`, but also valid on Python 3.\n",
    "        print('%s %s' % (i, mae))\n",
    "print(mae / len(img_paths))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
